#---------------------------------------------------------
# Are Intermediary Constraints Priced?
# Du, Hébert, Huber
# The Review of Financial Studies 2022
#---------------------------------------------------------

#-----------------------------------------------
# Functions for forward arbitrage
# This code constructs all functions to be called in analysis scripts.
# This version simplifies for online publication
#-----------------------------------------------

# ----------------------------------------------
# Libraries and constants
# ----------------------------------------------

library(haven)
library(readxl)
library(zoo)
library(gmm)
library(xtable)
library(lubridate)
library(ggplot2)
library(data.table)

#--------------------------------------

# Sample window and key event dates (Date objects built with lubridate::ymd).
start_date <- ymd('2003-01-01')
end_date <- ymd('2018-07-31')
trading_end_date <- ymd('2020-11-30')
extended_end_date <- ymd('2020-12-31')
start_crisis_date <- ymd('2007-07-01')
post_crisis_date <- ymd('2010-07-01')
slr_start_date <- ymd('2015-01-01')
covid_start_date <- ymd('2020-01-01')

# Day-zero origins of the date systems used by R, Excel and Stata
# (presumably passed as `origin =` when converting numeric dates).
r_origin <- '1970-01-01'
excel_origin <- '1899-12-30'
stata_origin <- '1960-01-01'

# Default number of sub-periods and the display order of sub-sample labels.
number_of_periods <- 3
slr_order <- c('Pre-2015', 'Post-2015')
covid_order <- c('Pre-2020', 'Post-2020')

#--------------------------------------

generate_currency_names <- function(number_of_currencies) {

  # Returns the currency universe (ISO codes, USD excluded) for a supported
  # sample size. currency_names is a key parameter used throughout the files.
  #
  # Args:
  #  number_of_currencies: numeric, one of 5, 6, 7 or 10.
  #
  # Returns:
  #  character vector of currency codes.
  #
  # Errors (without the call) for any other size.

  supported_sizes <- c(10, 7, 6, 5)
  currency_sets <- list(
    c('AUD', 'CAD', 'CHF', 'DKK', 'EUR', 'GBP', 'JPY', 'NOK', 'NZD', 'SEK'),
    c('AUD', 'CAD', 'CHF', 'EUR', 'GBP', 'JPY', 'NZD'),
    c('AUD', 'CAD', 'GBP', 'EUR', 'CHF', 'JPY'),  # in decreasing order of 1M-f 3M returns
    c('AUD', 'CAD', 'EUR', 'GBP', 'JPY')
  )

  set_index <- match(number_of_currencies, supported_sizes)
  if (is.na(set_index)) {
    stop('Invalid number of currencies.  Please try 5, 6, 7, or 10.', call. = FALSE)
  }

  currency_sets[[set_index]]
}

#--------------------------------------

tenor_lookup <- local({

  # Tenor grid in months: 0-11 monthly; then, for each year k = 1..9, the
  # 12k, 12k + 1 and 12k + 3 month points; then annual tenors from 10 to 40
  # years.
  monthly <- seq.int(from = 0, to = 11, by = 1)
  around_each_year <- as.vector(outer(c(0, 1, 3), 12 * (1:9), '+'))
  long_dated <- c(10, 11, 12, 15, 20, 25, 30, 40) * 12

  # Long labels, aligned one-to-one with the months above. Spellings such as
  # 'fourtynine_mth' and 'forth_yr' are kept verbatim; they are presumably
  # matched against column names elsewhere.
  long_names <- c('spot', 'one_mth', 'two_mth', 'three_mth', 'four_mth', 'five_mth', 'six_mth', 'seven_mth', 'eight_mth', 'nine_mth', 'ten_mth', 'eleven_mth',
                  'one_yr', 'thirteen_mth', 'fifteen_mth',
                  'two_yr', 'twentyfive_mth', 'twentyseven_mth',
                  'three_yr', 'thirtyseven_mth', 'thirtynine_mth',
                  'four_yr', 'fourtynine_mth', 'fiftyone_mth',
                  'five_yr', 'sixtyone_mth', 'sixtythree_mth',
                  'six_yr', 'seventythree_mth', 'seventyfive_mth',
                  'seven_yr', 'eightyfive_mth', 'eightyseven_mth',
                  'eight_yr', 'ninetyseven_mth', 'ninetynine_mth',
                  'nine_yr', 'oneohnine_mth', 'oneeleven_mth', 'ten_yr',
                  'eleven_yr', 'twelve_yr', 'fifteen_yr', 'twenty_yr',
                  'twentyfive_yr', 'thirty_yr', 'forth_yr')

  # Short labels ('spot', '1M'..'11M', '1Y', '13M', ...), same alignment.
  short_names <- c('spot', paste0(seq_len(11), 'M'),
                   '1Y', '13M', '15M', '2Y', '25M', '27M', '3Y', '37M', '39M',
                   '4Y', '49M', '51M', '5Y', '61M', '63M', '6Y', '73M', '75M',
                   '7Y', '85M', '87M', '8Y', '97M', '99M', '9Y', '109M', '111M',
                   '10Y', '11Y', '12Y', '15Y', '20Y', '25Y', '30Y', '40Y')

  data.frame(tenor_month = c(monthly, around_each_year, long_dated),
             tenor_name = long_names,
             tenor_short_name = short_names,
             stringsAsFactors = FALSE)
})

#--------------------------------------

currency_convention <- local({

  # Market conventions per interest-rate series, one row per series:
  #  day_count:     denominator of the act/day_count accrual convention
  #  fwd_factor:    100 for JPY, 10000 otherwise; presumably the divisor that
  #                 turns quoted forward points into exchange-rate units -- TODO confirm
  #  base_currency: currency recorded as the base of the pair's USD quote
  #                 (AUD, EUR, GBP, NZD are their own base; the rest use USD)
  # Rows: 'USD', then '<CCY>_ois' and '<CCY>_ibor' for the ten currencies of
  # generate_currency_names(10), in that order.

  fx_names <- generate_currency_names(number_of_currencies = 10)

  ois_day_count <- c(AUD = 365, CAD = 365, CHF = 360, DKK = 360, EUR = 360,
                     GBP = 365, JPY = 365, NOK = 360, NZD = 365, SEK = 360)
  # Same as OIS except JPY, whose IBOR series uses a 360-day denominator.
  ibor_day_count <- c(AUD = 365, CAD = 365, CHF = 360, DKK = 360, EUR = 360,
                      GBP = 365, JPY = 360, NOK = 360, NZD = 365, SEK = 360)
  fwd_factor <- c(AUD = 10000, CAD = 10000, CHF = 10000, DKK = 10000, EUR = 10000,
                  GBP = 10000, JPY = 100, NOK = 10000, NZD = 10000, SEK = 10000)
  base_currency <- c(AUD = 'AUD', CAD = 'USD', CHF = 'USD', DKK = 'USD', EUR = 'EUR',
                     GBP = 'GBP', JPY = 'USD', NOK = 'USD', NZD = 'NZD', SEK = 'USD')

  data.frame(
    day_count = as.integer(c(360, ois_day_count[fx_names], ibor_day_count[fx_names])),
    fwd_factor = as.integer(c(10000, fwd_factor[fx_names], fwd_factor[fx_names])),
    base_currency = unname(c('USD', base_currency[fx_names], base_currency[fx_names])),
    row.names = c('USD', paste(fx_names, 'ois', sep = '_'), paste(fx_names, 'ibor', sep = '_')),
    stringsAsFactors = FALSE
  )
})

# ----------------------------------------------
# Funding -> investing currency pairs, one table per rate type (OIS and IBOR).
ois_table <- data.frame(
  fund_currency   = c('AUD', 'AUD', 'USD', 'USD', 'CAD', 'AUD', 'GBP', 'USD', 'CAD', 'GBP'),
  invest_currency = c('CHF', 'JPY', 'CHF', 'JPY', 'CHF', 'EUR', 'CHF', 'EUR', 'JPY', 'JPY'),
  stringsAsFactors = FALSE
)

ibor_table <- data.frame(
  fund_currency   = c('AUD', 'AUD', 'AUD', 'AUD', 'USD', 'USD', 'USD', 'USD', 'AUD', 'GBP'),
  invest_currency = c('JPY', 'EUR', 'CHF', 'CAD', 'JPY', 'EUR', 'CHF', 'CAD', 'GBP', 'JPY'),
  stringsAsFactors = FALSE
)

# Each of the six main currencies paired with USD as the investing currency
# ('USD' is recycled to one row per currency).
single_table <- data.frame(
  fund_currency = generate_currency_names(number_of_currencies = 6),
  invest_currency = 'USD',
  stringsAsFactors = FALSE
)

# ----------------------------------------------
# Utility functions
# ----------------------------------------------

comb <- function(x, ...) {

  # .combine helper for foreach (doParallel). Each foreach result is a list of
  # n objects; this appends the i-th object of every new result to the i-th
  # accumulator in x, i.e. it transposes a stream of n-element results into
  # n lists.
  #
  # Args:
  #  x: list of n accumulators (e.g. the .init value or the previous output).
  #  ...: further results, each a list with at least n elements.
  #
  # Returns:
  #  list of length n; element i is x[[i]] followed by the i-th element of
  #  each result in ..., in order.

  new_results <- list(...)
  lapply(seq_along(x), function(slot) {
    c(x[[slot]], lapply(new_results, `[[`, slot))
  })
}

multiple_reduced_join <- function(merge_key, merge_all_x, merge_all_y, data_list) {

  # Joins every table in data_list on merge_key, folding from left to right.
  # merge() only disambiguates clashing column names with .x / .y, which stops
  # being unique after a couple of joins. So after the k-th merge, the non-key
  # columns that came from data_list[[k + 1]] are renamed to '<name>.<k>'.
  #
  # Args:
  #  merge_key: string (or character vector), name(s) of the column(s) to merge on
  #  merge_all_*: TRUE or FALSE, passed to merge() as all.x / all.y
  #  data_list: list of tables (data.table / data.frame) to be joined
  #
  # Returns:
  #  the joined table; data_list[[1]] unchanged when the list has a single
  #  element, and NULL when it is empty (as Reduce() returned).

  if (length(data_list) == 0) {
    return(NULL)
  }

  joined_data_table <- data_list[[1]]
  for (counter in seq_len(length(data_list) - 1)) {
    right_table <- data_list[[counter + 1]]
    left_width <- length(names(joined_data_table))
    joined_data_table <- merge(joined_data_table, right_table,
                               all.x = merge_all_x, all.y = merge_all_y, by = merge_key)

    # merge() places the key column(s) first, then the left table's remaining
    # columns, then the right table's remaining columns in their original
    # order, so the right-hand columns start at position left_width + 1.
    # Their original names are taken with setdiff() instead of assuming the
    # key is the first column of right_table. This is correct for any key
    # position and for multi-column keys. A right table holding only key
    # columns needs no renaming (the old (n + 1):n index ran backwards there).
    right_value_names <- setdiff(names(right_table), merge_key)
    if (length(right_value_names) > 0) {
      setnames(joined_data_table, left_width + seq_along(right_value_names),
               paste(right_value_names, counter, sep = '.'))
    }
  }

  return(joined_data_table)
}

#--------------------------------------

generate_period_order <- function(number_of_periods) {

  # Returns the chronologically ordered sub-period labels for a given
  # granularity.
  #
  # Args:
  #  number_of_periods: integer, one of 3, 5 or 6.
  #
  # Returns:
  #  character vector of period labels. Any other input is an error.

  period_sets <- list(
    c('Pre-Crisis', 'Crisis', 'Post-Crisis'),
    c('Normal', 'Pre-Crisis', 'Crisis', 'Post-Crisis', 'Post-Dodd Frank'),
    c('Normal', 'Pre-Crisis', 'Crisis', 'Post-Crisis', 'Post-Dodd Frank', 'Post-Basel III')
  )

  set_index <- match(number_of_periods, c(3, 5, 6))
  if (is.na(set_index)) {
    stop('Invalid input, please enter 3, 5, or 6.')
  }

  period_sets[[set_index]]
}
period_order <- generate_period_order(number_of_periods)  # labels for the default sample split

#--------------------------------------

quarter_start <- function(data_table, date_column) {

  # Converts a column of 'QX YYYY' strings (e.g. 'Q3 2015') to the first date
  # of that quarter (2015-07-01). The column is replaced by reference and the
  # table is also returned; unrecognised quarter labels become NA.
  #
  # Args:
  #  data_table: name of the data_table, NOT in quote
  #  date_column: STRING, name of column containing dates in 'QX YYYY' format
  #
  # The quarter / year parts are kept in local vectors rather than temporary
  # 'Quarter', 'Year' and 'numeric_quarter' columns. Previously those helper
  # columns were written and then dropped, silently overwriting and deleting
  # any caller columns with the same names.

  quarter_first_day <- c(Q1 = '01-01', Q2 = '04-01', Q3 = '07-01', Q4 = '10-01')

  date_parts <- tstrsplit(data_table[[date_column]], ' ', fixed = TRUE)
  quarter_label <- if (length(date_parts) >= 1) date_parts[[1]] else character(0)
  year_label <- if (length(date_parts) >= 2) date_parts[[2]] else rep(NA_character_, length(quarter_label))

  # Lookup by name yields NA for anything other than Q1-Q4.
  month_day <- unname(quarter_first_day[quarter_label])
  quarter_first_dates <- ymd(paste(year_label, month_day, sep = '-'))
  data_table[, (date_column) := quarter_first_dates]

  return(data_table)
}

#--------------------------------------

month_to_quarter <- function(month_number) {

  # Converts month numbers to quarter labels, vectorised:
  # (0, 3] -> 'Q1', (3, 6] -> 'Q2', (6, 9] -> 'Q3', (9, 12] -> 'Q4'.
  #
  # Args:
  #  month_number: numeric vector of months
  #
  # Returns:
  #  character vector of quarter labels; NA for NA input and for any value
  #  outside (0, 12]. Months <= 0 used to be labelled 'Q1' while months > 12
  #  were NA; both out-of-range sides are now NA.

  in_range <- month_number > 0 & month_number <= 12
  quarter_number <- ifelse(in_range, paste0('Q', ceiling(month_number / 3)), NA)

  return(quarter_number)
}

#--------------------------------------

date_to_QE <- function(date_column) {

  # Maps each date to the last calendar day of its quarter
  # (Mar 31, Jun 30, Sep 30 or Dec 31 of the same year).
  #
  # Args:
  #  date_column: dates accepted by lubridate::quarter() and year()
  #
  # Returns:
  #  Date vector; NA where the input is NA.
  #
  # The quarter-end date is built directly as a Date. The previous ifelse()
  # chain stripped the Date class and then called as.Date() on a bare number
  # without an origin, which errors on R < 4.3. It also parsed four full
  # vectors instead of one.

  quarter_end_month_day <- c('03-31', '06-30', '09-30', '12-31')
  month_day <- quarter_end_month_day[quarter(date_column)]
  # quiet = TRUE: the only unparseable strings here come from NA inputs.
  QE <- ymd(paste(year(date_column), month_day, sep = '-'), quiet = TRUE)

  return(as.Date(QE))
}

#--------------------------------------

create_sample_period <- function(number_of_periods, data_table, date_variable, period_order) {

  # Labels every row of data_table with the sub-period its date falls in.
  #
  # Args:
  #  number_of_periods: num, 3, 5 or 6 -- how finely to split the sample.
  #  data_table: data table object (NOT a string); the "period" column is
  #              written into it by reference.
  #  date_variable: string, name of the column containing dates.
  #  period_order: character, period labels in chronological order.
  #
  # Output:
  #  data_table with the (new or overwritten) column "period".
  #
  # Each cutoff date opens the next period:
  #  3 periods: start_date (global), 2007-07-01, 2010-07-01; rows dated before
  #             start_date are labelled NA.
  #  5 periods: 2006-01-01, 2007-07-01, 2009-05-01, 2010-07-01.
  #  6 periods: the 5-period cutoffs plus 2015-01-01.

  if (!(number_of_periods == 3 | number_of_periods == 5 | number_of_periods == 6)) {
    stop('Invalid number of periods.  Please try 3, 5, or 6.', call. = FALSE)
  }

  pre_crisis_begins <- ymd('2006-01-01')
  crisis_begins <- ymd('2007-07-01')
  post_crisis_begins <- ymd('2009-05-01')
  dodd_frank_begins <- ymd('2010-07-01')
  basel_iii_begins <- ymd('2015-01-01')

  if (number_of_periods == 3) {
    cutoffs <- list(start_date, crisis_begins, dodd_frank_begins)
    labels <- list(NA, period_order[1], period_order[2], period_order[3])
  } else {
    cutoffs <- list(pre_crisis_begins, crisis_begins, post_crisis_begins, dodd_frank_begins)
    if (number_of_periods == 6) {
      cutoffs <- c(cutoffs, list(basel_iii_begins))
    }
    labels <- as.list(period_order[seq_len(length(cutoffs) + 1)])
  }

  # Walk the cutoffs from last to first. This reproduces the nested chain
  # ifelse(d < c1, l1, ifelse(d < c2, l2, ..., l_last)), including NA for
  # missing dates.
  dates <- data_table[[date_variable]]
  labelled <- labels[[length(labels)]]
  for (k in rev(seq_along(cutoffs))) {
    labelled <- ifelse(dates < cutoffs[[k]], labels[[k]], labelled)
  }

  data_table[, period := labelled]
  return(data_table)
}

#--------------------------------------

make_stars <- function(values, df) {

  # Maps t-statistics to significance stars using two-sided Student-t
  # critical values.
  #
  # Args:
  #  values: numeric vector of t-statistics.
  #  df: numeric vector of sample sizes matched to values (usual recycling);
  #      critical values use df - 1 degrees of freedom, i.e. a
  #      degrees-of-freedom adjustment of only 1.
  #
  # Returns:
  #  character vector: '***' (1%), '**' (5%), '*' (10%) or '' otherwise;
  #  NA where values is NA.

  # qt() is vectorised over its df argument, so no per-element sapply() is
  # needed. This also lets zero-length input return an empty result instead
  # of failing on sig_values[1, ].
  critical_1pct <- qt(0.995, df - 1)
  critical_5pct <- qt(0.975, df - 1)
  critical_10pct <- qt(0.95, df - 1)

  abs_values <- abs(values)
  sig_stars <- ifelse(abs_values >= critical_1pct, '***',
                      ifelse(abs_values >= critical_5pct, '**',
                             ifelse(abs_values >= critical_10pct, '*', '')))
  return(sig_stars)
}

#--------------------------------------

correct_outliers <- function(missing_or_correct, ...) {
  
  # Correct manually identified outliers in one of two ways:
  # (1) 'missing': set each flagged observation to NA
  # (2) 'correct': replace it with a value obtained via other sources (CMPN, BGN)
  #
  # The series are NOT passed in: fwd_AUD, fwd_CAD, fwd_CHF, ois_AUD and ois_CHF
  # are free variables, looked up in the environment this function was defined
  # in (presumably the global environment after sourcing this file). Because
  # the `$<-` replacements below assign to local copies, the caller's objects
  # are not modified. Callers must use the returned list.
  #
  # Each correction targets one tenor column on one exact Date. If that Date is
  # not present, which() is empty and the assignment changes nothing.
  # Commented-out lines are previously flagged points that are left as-is.
  # 
  # Args: 
  #  missing_or_correct: string, 'missing' or 'correct' (anything else errors)
  #  ...: unused
  #
  # Output:
  #  named list with the adjusted fwd_AUD, fwd_CAD, fwd_CHF, ois_AUD, ois_CHF
  #
  # NOTE(review): the 'missing' branch blanks fwd_CAD$ten_mth on 2016-10-27, but
  # the 'correct' branch has no replacement for that point, so it keeps its raw
  # value under 'correct' -- confirm this is intended.
  
  if (!(missing_or_correct == 'missing' | missing_or_correct == 'correct')) {
    stop('Invalid method to correct outliers.  Please try \'missing\' or \'correct\'.', call. = FALSE)
  } else if (missing_or_correct == 'missing') {
    
    # Forward points: blank out each flagged observation.
    fwd_AUD$two_mth[which(fwd_AUD$Date == ymd('2007-06-29'))] <- NA
    fwd_CAD$four_mth[which(fwd_CAD$Date == ymd('2005-09-06'))] <- NA
    fwd_CAD$four_mth[which(fwd_CAD$Date == ymd('2013-12-26'))] <- NA
    fwd_CAD$four_mth[which(fwd_CAD$Date == ymd('2015-02-05'))] <- NA
    fwd_CAD$ten_mth[which(fwd_CAD$Date == ymd('2016-10-27'))] <- NA
    # fwd_DKK$one_mth[which(fwd_DKK$Date == ymd('2012-04-10'))] <- NA
    # fwd_NZD$three_mth[which(fwd_NZD$Date == ymd('2015-06-30'))] <- NA
    # fwd_NZD$four_mth[which(fwd_NZD$Date == ymd('2015-06-30'))] <- NA
    # fwd_NZD$four_mth[which(fwd_NZD$Date == ymd('2004-08-09'))] <- NA
    # fwd_SEK$two_mth[which(fwd_SEK$Date == ymd('2008-01-18'))] <- NA
    # fwd_SEK$four_mth[which(fwd_SEK$Date == ymd('2011-12-22'))] <- NA
    fwd_CHF$four_mth[which(fwd_CHF$Date == ymd('2017-09-05'))] <- NA
    
    # OIS rates: blank out each flagged observation.
    ois_AUD$four_mth[which(ois_AUD$Date == ymd('2002-04-08'))] <- NA
    ois_CHF$two_mth[which(ois_CHF$Date == ymd('2007-07-05'))] <- NA
    # ois_DKK$six_mth[which(ois_DKK$Date == ymd('2013-09-11'))] <- NA
    # ois_NZD$two_mth[which(ois_NZD$Date == ymd('2011-04-06'))] <- NA
    # ois_NZD$four_mth[which(ois_NZD$Date == ymd('2014-10-27'))] <- NA
    # ois_NZD$two_mth[which(ois_NZD$Date == ymd('2011-04-26'))] <- NA
    
  } else if (missing_or_correct == 'correct') {
    
    # Forward points: overwrite with the alternative-source value.
    fwd_AUD$two_mth[which(fwd_AUD$Date == ymd('2007-06-29'))] <- -14.4
    fwd_CAD$four_mth[which(fwd_CAD$Date == ymd('2005-09-06'))] <- -37.2
    fwd_CAD$four_mth[which(fwd_CAD$Date == ymd('2013-12-26'))] <- 31.67
    fwd_CAD$four_mth[which(fwd_CAD$Date == ymd('2015-02-05'))] <- 18.88
    # fwd_DKK$one_mth[which(fwd_DKK$Date == ymd('2012-04-10'))] <- -15.36
    # fwd_NZD$three_mth[which(fwd_NZD$Date == ymd('2015-06-30'))] <- -53.2
    # fwd_NZD$four_mth[which(fwd_NZD$Date == ymd('2015-06-30'))] <- -68.43
    # fwd_NZD$four_mth[which(fwd_NZD$Date == ymd('2004-08-09'))] <- -97.75
    # fwd_SEK$two_mth[which(fwd_SEK$Date == ymd('2008-01-18'))] <- 42
    # fwd_SEK$four_mth[which(fwd_SEK$Date == ymd('2011-12-22'))] <- 307
    fwd_CHF$four_mth[which(fwd_CHF$Date == ymd('2017-09-05'))] <- -80.8
    
    # OIS rates: overwrite with the alternative-source value.
    ois_AUD$four_mth[which(ois_AUD$Date == ymd('2002-04-08'))] <- 4.57
    ois_CHF$two_mth[which(ois_CHF$Date == ymd('2007-07-05'))] <- 2.59
    # ois_NZD$two_mth[which(ois_NZD$Date == ymd('2011-04-06'))] <- 2.51
    # ois_NZD$four_mth[which(ois_NZD$Date == ymd('2014-10-27'))] <- 3.52
    # ois_NZD$two_mth[which(ois_NZD$Date == ymd('2011-04-26'))] <- 2.52
    
  }
  
  # Return the (local) adjusted copies under their original names.
  list_to_return <- list(fwd_AUD, fwd_CAD, fwd_CHF, ois_AUD, ois_CHF)
  names(list_to_return) <- c('fwd_AUD', 'fwd_CAD', 'fwd_CHF', 'ois_AUD', 'ois_CHF')
  return(list_to_return)
}  

#--------------------------------------

days_with_complete_info <- function(data_list, currency_names, column_to_index) {

  # Returns the dates on which EVERY listed currency has a non-missing value
  # in column_to_index.
  #
  # Args:
  #  data_list: list of data tables indexed by currency name, each with a
  #             'Date' column (typically Data_one_mth_in_X_mth...).
  #  currency_names: character, the currencies that all need entries.
  #  column_to_index: string, the column whose availability matters, i.e. 'PI'.
  #
  # Returns:
  #  one-column data table ('Date') of the fully covered days.

  # Previously `1:length(currency_names)` turned an empty vector into the
  # indices 1, 0 and failed with an unrelated subscript error.
  if (length(currency_names) == 0) {
    stop('currency_names must contain at least one currency.', call. = FALSE)
  }

  merge_list <- lapply(currency_names, function(currency) {
    data_list[[currency]][, c('Date', column_to_index), with = FALSE]
  })

  # Full outer join on Date, so a day missing for any currency shows up as NA.
  temp <- multiple_reduced_join(merge_key = 'Date', merge_all_x = TRUE, merge_all_y = TRUE, data_list = merge_list)
  names(temp) <- c('Date', currency_names)

  # Vectorised replacement for counting non-NA values row by row with apply().
  complete_days <- complete.cases(temp[, currency_names, with = FALSE])
  days_to_keep <- temp[complete_days, 'Date']

  return(days_to_keep)
}

#--------------------------------------

compute_compound_interest <- function(rate_annualized, day_count, duration) {

  # Gross growth factor from accruing an annualised rate over a period with
  # simple (act/day_count) interest:
  #   1 + duration * rate_annualized / day_count / 100
  # Despite the name, there is no compounding within the period.
  #
  # Args:
  #  rate_annualized: double, annualised rate in percent (2% p.a. -> 2)
  #  day_count: integer, denominator of act/day_count
  #  duration: integer, days to maturity, numerator of act/day_count
  #
  # Returns:
  #  double, e.g. 1.005 for 2% over 90 days on act/360. Vectorised.

  accrued_fraction <- duration * rate_annualized / day_count / 100
  accrued_fraction + 1
}

#--------------------------------------

compute_annulized_rate <- function(compounded_interest, day_count, duration) {

  # Inverse of simple act/day_count accrual: recovers the annualised rate
  # (in percent) that turns 1 into compounded_interest over `duration` days,
  #   (compounded_interest - 1) * day_count * 100 / duration
  #
  # Args:
  #  compounded_interest: double, gross growth factor over the period,
  #                       i.e. 1 + duration * rate / day_count / 100
  #  day_count: integer, denominator of act/day_count
  #  duration: integer, days to maturity
  #
  # Returns:
  #  double, annualised rate in percent. Vectorised.

  growth_over_period <- compounded_interest - 1
  growth_over_period * day_count * 100 / duration
}

#--------------------------------------

compute_log_currency_basis <- function(domestic_annualized_rate, foreign_annualized_rate, start_exchange_rate, end_exchange_rate, 
                                       domestic_day_count, foreign_day_count, duration) {

  # Annualised log currency basis, analogous to Du et al.: the direct
  # (domestic) log return minus the synthetic return from investing at the
  # foreign rate and converting back at end_exchange_rate, both annualised
  # with domestic_day_count / duration:
  #   100 * D/T * [ log(1 + r_d T / D) - ( log(1 + r_f T / F) - log(E_end / E_start) ) ]
  # with D = domestic_day_count, F = foreign_day_count, T = duration.
  # (Assumes the exchange rate is quoted so that log(E_end / E_start) is the
  # foreign-currency forward premium, e.g. foreign units per USD -- confirm.)
  #
  # Args:
  #  *_annualized_rate: numeric, stated in percentages, e.g. 2% per annum -> 2
  #  start_exchange_rate: numeric, at the start of the contract, e.g. spot if starting today
  #  end_exchange_rate: numeric, rate at which the position is converted back
  #  *_day_count: denominators of the act/day_count conventions
  #  duration: in days
  #
  # Returns:
  #  numeric, annualised basis in percent (same units as the input rates).

  annualizer <- domestic_day_count / duration

  log_foreign_growth <- log(1 + foreign_annualized_rate / 100 * duration / foreign_day_count)
  log_fx_change <- log(end_exchange_rate / start_exchange_rate)
  synthetic_rate <- annualizer * (log_foreign_growth - log_fx_change)

  log_domestic_growth <- log(1 + domestic_annualized_rate / 100 * duration / domestic_day_count)
  direct_rate <- annualizer * log_domestic_growth

  (direct_rate - synthetic_rate) * 100
}

#--------------------------------------

compute_log_fund_invest_basis <- function(fund_annualized_rate, invest_annualized_rate, 
                                       fund_start_exchange_rate, fund_end_exchange_rate, 
                                       invest_start_exchange_rate, invest_end_exchange_rate, 
                                       USD_day_count, fund_day_count, invest_day_count, 
                                       fund_duration, invest_duration) {

  # Annualised log basis between lending the invest currency and borrowing the
  # fund currency, each converted to a synthetic dollar rate (analogous to
  # Du et al.). Usable for any currency pair where the base is not USD.
  # For each leg X:
  #   synthetic_X = USD_day_count / T_X *
  #                 ( log(1 + r_X / 100 * T_X / D_X) - log(E_X_end / E_X_start) )
  # and the result is 100 * (synthetic_invest - synthetic_fund).
  #
  # Args:
  #  *_annualized_rate: numeric, stated in percentages, e.g. 2% per annum -> 2
  #  *_exchange_rate: numeric, start / end rates used for each leg
  #  *_day_count: denominators of the act/day_count conventions
  #  *_duration: in days
  #
  # Returns:
  #  numeric, annualised basis in percent.

  fund_growth <- log(1 + fund_annualized_rate / 100 * fund_duration / fund_day_count)
  fund_fx_move <- log(fund_end_exchange_rate / fund_start_exchange_rate)
  fund_synthetic <- USD_day_count / fund_duration * (fund_growth - fund_fx_move)

  invest_growth <- log(1 + invest_annualized_rate / 100 * invest_duration / invest_day_count)
  invest_fx_move <- log(invest_end_exchange_rate / invest_start_exchange_rate)
  invest_synthetic <- USD_day_count / invest_duration * (invest_growth - invest_fx_move)

  (invest_synthetic - fund_synthetic) * 100
}

#--------------------------------------

linear_imputation <- function(start_level, end_level, period_in_between, period_to_impute) {

  # Linearly interpolates a level (spot basis OR interest rate) between
  # start_level and end_level:
  #   start_level + (end_level - start_level) / period_in_between * period_to_impute
  #
  # Args:
  #  start_level: column of price (basis, rates) for the start
  #  end_level: column of price (basis, rates) for the end
  #  period_in_between: numeric, in months
  #  period_to_impute: numeric, months since start_period
  #
  # Returns:
  #  numeric vector of interpolated levels.

  step_per_period <- (end_level - start_level) / period_in_between
  start_level + step_per_period * period_to_impute
}

#-----------------------------------
# Analysis functions
#-----------------------------------

moment_function <- function(parameter, dat) {

  # Per-observation moment conditions identifying the mean and the
  # (N - 1)-denominator standard deviation of dat:
  #   m1 = mu - x
  #   m2 = sigma^2 - (x - mu)^2 * N / (N - 1)
  #
  # Args:
  #  parameter: vector of parameters, c(mu, sigma)
  #  dat: data as a vector
  #
  # Returns:
  #  length(dat) x 2 matrix with columns 'm1' and 'm2'.

  mu <- parameter[1]
  sigma <- parameter[2]
  N <- length(dat)
  squared_deviation <- (dat - mu)^2

  cbind(m1 = mu - dat,
        m2 = sigma^2 - squared_deviation * N / (N - 1))
}

#--------------------------------------

gradient_function <- function(parameter, dat) {

  # Jacobian of the sample-averaged moment conditions of moment_function with
  # respect to (mu, sigma). Rows are the two moments and columns the two
  # parameters (presumably supplied to gmm() as `gradv` -- confirm at the call
  # site).
  #   m1 = mu - x                        -> d/dmu = 1,   d/dsigma = 0
  #   m2 = sigma^2 - (x - mu)^2 N/(N-1)  -> d/dmu = 2 N/(N-1) (mean(x) - mu),
  #                                         d/dsigma = 2 sigma
  #
  # The d m2 / d mu entry previously had the wrong sign and lacked the
  # N/(N-1) factor. Both vanish at the estimate mu = mean(x), so results at the
  # estimate are essentially unchanged, but the Jacobian is now correct
  # everywhere.
  #
  # Args:
  #  parameter: vector of parameters, c(mu, sigma)
  #  dat: data as a vector
  #
  # Returns:
  #  2 x 2 numeric matrix.

  N <- length(dat)
  d_m2_d_mu <- 2 * N / (N - 1) * (mean(dat) - parameter[1])
  G <- matrix(c(1, d_m2_d_mu, 0, 2 * parameter[2]), ncol = 2)
  return(G)
}

#--------------------------------------

compute_top_weight <- function(portfolio_data, type_of_reference, top_number, currency_names, currency_values, interest_reference = NULL) {

  # Row by row, equal-weights the top_number currencies with the largest
  # reference value. Each selected currency gets 1 / top_number, all others 0.
  # Ties go to the currency listed first and missing values rank last, as with
  # order(decreasing = TRUE).
  #
  # Args:
  #  portfolio_data: data table, name of the data table containing all columns.
  #  type_of_reference: string, 'basis' (spot log 3M basis) OR 'interest'
  #                     (funding currency 3M interest - investing currency 3M interest).
  #  top_number: numeric, number of currencies held on each row.
  #  currency_names: string, names of the currencies (one weight column each).
  #  currency_values: string, names of columns containing 3M basis or funding currency 3M interest.
  #  interest_reference: string, applicable only for 'interest', names of the
  #                      investing currency 3M interest columns.
  #
  # Output:
  #  portfolio_data with '<currency>_weight' columns appended. The
  #  '<currency>_difference' ranking columns are also added to the input table
  #  by reference.
  #
  # The rank and weight matrices are built with an explicit n_rows x k shape.
  # The former t(apply(...)) returned a 1 x n_rows matrix when top_number (or
  # the number of currencies) was 1, which broke the rank/weight tables.

  if (!(type_of_reference %in% c('basis', 'interest'))) {
    stop('Invalid type of reference: \'basis\' or \'interest\'.', call. = FALSE)
  }

  difference_names <- paste(currency_names, 'difference', sep = '_')
  weight_names <- paste(currency_names, 'weight', sep = '_')

  # Ranking signal per currency, stored on portfolio_data by reference.
  for (i in seq_along(difference_names)) {
    ranking_value <- portfolio_data[[currency_values[i]]]
    if (type_of_reference == 'interest') {
      ranking_value <- ranking_value - portfolio_data[[interest_reference[i]]]
    }
    portfolio_data[, (difference_names[i]) := ranking_value]
  }

  n_rows <- nrow(portfolio_data)
  n_currencies <- length(currency_names)
  difference_matrix <- as.matrix(portfolio_data[, difference_names, with = FALSE])

  # Column indices of the top_number largest differences on each row;
  # NA when top_number exceeds the number of currencies.
  top_columns <- matrix(
    vapply(seq_len(n_rows),
           function(r) order(difference_matrix[r, ], decreasing = TRUE)[seq_len(top_number)],
           integer(top_number)),
    nrow = n_rows, ncol = top_number, byrow = TRUE)

  weight_matrix <- matrix(0, nrow = n_rows, ncol = n_currencies)
  for (k in seq_len(top_number)) {
    picked <- top_columns[, k]
    has_pick <- !is.na(picked)
    weight_matrix[cbind(which(has_pick), picked[has_pick])] <- 1 / top_number
  }

  temp_weight <- as.data.table(weight_matrix)
  setnames(temp_weight, weight_names)
  portfolio_data <- cbind(portfolio_data, temp_weight)
  return(portfolio_data)
}

#--------------------------------------

compute_dynamic_carry_weight <- function(portfolio_data, type_of_reference, currency_values, reference_values) {

  # Dynamic carry weights. Each currency's signal (3M log forward premium, or
  # 3M interest differential vs the reference) is demeaned across currencies
  # row by row. Currencies at or above the cross-sectional average are held
  # long in proportion to their distance from it (long weights sum to +1).
  # Those below it are held short in the same way (short weights sum to -1).
  # Missing values are ignored in both sums.
  #
  # Args:
  #  portfolio_data: data table, name of the data table containing all columns
  #                  (columns are added by reference).
  #  type_of_reference: string, 'forward' (log 3M forward - log spot) OR
  #                     'interest' (foreign 3M interest - USD 3M interest).
  #  currency_values: string, names of columns containing 3M forward rates or
  #                   foreign 3M interest; their first 3 characters name the
  #                   new columns.
  #  reference_values: string, names of columns containing 3M spot rates or
  #                    USD 3M interest -- either one per currency_values entry
  #                    or a single column shared by all currencies.
  #
  # Output:
  #  portfolio_data with '<ccy>_difference', 'average_reference',
  #  '<ccy>_relative_to_average', 'long_reference', 'short_reference' and
  #  '<ccy>_weight' columns added.

  if (!(type_of_reference %in% c('forward', 'interest'))) {
    stop('Invalid type of reference: \'forward\' or \'interest\'.', call. = FALSE)
  }

  currency_codes <- substr(currency_values, 1, 3)
  difference_names <- paste(currency_codes, 'difference', sep = '_')
  relative_to_avg_names <- paste(currency_codes, 'relative_to_average', sep = '_')
  weight_names <- paste(currency_codes, 'weight', sep = '_')

  # One reference per currency, or one shared column (e.g. USD interest). The
  # old 'interest' branch used get(reference_values), which only accepts a
  # single name.
  reference_values <- rep_len(reference_values, length(currency_values))

  for (i in seq_along(currency_values)) {
    own_value <- portfolio_data[[currency_values[i]]]
    reference_value <- portfolio_data[[reference_values[i]]]
    signal_value <- if (type_of_reference == 'forward') {
      log(own_value) - log(reference_value)
    } else {
      own_value - reference_value
    }
    portfolio_data[, (difference_names[i]) := signal_value]
  }

  portfolio_data[, average_reference := apply(.SD, 1, function(x) mean(x, na.rm = TRUE)), .SDcols = difference_names]
  portfolio_data[, (relative_to_avg_names) := lapply(.SD, function(x) x - average_reference), .SDcols = difference_names]
  portfolio_data[, long_reference := apply(.SD, 1, function(x) sum(x[x >= 0], na.rm = TRUE)), .SDcols = relative_to_avg_names]
  portfolio_data[, short_reference := apply(.SD, 1, function(x) sum(x[x < 0], na.rm = TRUE)), .SDcols = relative_to_avg_names]

  # Split long / short on the sign of the demeaned signal, the same split used
  # for long_reference and short_reference. The previous version compared the
  # demeaned signal with average_reference, so some longs were scaled by the
  # short total (and vice versa) and the legs no longer summed to +1 / -1.
  portfolio_data[, (weight_names) := lapply(.SD, function(x) ifelse(x >= 0, x / long_reference,
                                                                    -x / short_reference)), .SDcols = relative_to_avg_names]

  return(portfolio_data)
}


#--------------------------------------

compute_equal_basis_weight <- function(portfolio_data, currency_names, type_of_rate) {

  # Equal-weight long/short positions from the sign of each currency's spot
  # 3M log basis. The weight is +1/n when the basis is >= 0, -1/n when it is
  # negative, and NA when it is missing (n = number of currencies).
  #
  # Args:
  #  portfolio_data: data table with '<ccy>_<type_of_rate>_spot_log_basis_3M' columns.
  #  currency_names: string, names of currencies.
  #  type_of_rate: string, 'ois' or 'ibor'.
  #
  # Output:
  #  portfolio_data with '<ccy>_basis_sign' and '<ccy>_weight' columns added
  #  by reference.

  n_currencies <- length(currency_names)
  basis_columns <- paste(currency_names, type_of_rate, 'spot_log_basis_3M', sep = '_')
  sign_columns <- paste(currency_names, 'basis_sign', sep = '_')
  weight_columns <- paste(currency_names, 'weight', sep = '_')

  portfolio_data[, (sign_columns) := lapply(.SD, function(basis) ifelse(basis >= 0, 1, -1)), .SDcols = basis_columns]
  portfolio_data[, (weight_columns) := lapply(.SD, function(sign_value) sign_value / n_currencies), .SDcols = sign_columns]

  portfolio_data
}

#--------------------------------------

compute_equal_spread_weight <- function(portfolio_data, type_of_rate, currency_names, spread_top_name, spread_bottom_name) {

  # Equal-weight long/short positions from the sign of (top - bottom) for each
  # currency. The weight is +1/n when the spread is >= 0, -1/n when it is
  # negative, and NA when it is missing (n = number of currencies).
  #
  # Args:
  #  portfolio_data: data table, name of the data table containing all columns.
  #  type_of_rate: string, 'ois' or 'ibor'.
  #  currency_names: string, names of currencies.
  #  spread_top_name: string, 'log_basis_1M_fwd_3M' or 'spot_log_basis_3M' or 'log_future_basis_3M_in_1M'.
  #  spread_bottom_name: string, 'log_basis_1M_fwd_3M' or 'spot_log_basis_3M' or 'log_future_basis_3M_in_1M'.
  #
  # Output:
  #  portfolio_data with '<ccy>_spread_sign' and '<ccy>_weight' columns added
  #  by reference.

  number_of_currencies <- length(currency_names)
  sign_names <- paste(currency_names, 'spread_sign', sep = '_')
  weight_names <- paste(currency_names, 'weight', sep = '_')
  top_columns <- paste(currency_names, type_of_rate, spread_top_name, sep = '_')
  bottom_columns <- paste(currency_names, type_of_rate, spread_bottom_name, sep = '_')

  # The spread is computed on plain column vectors. The former
  # .SD[, 1] - .SD[, 2] >= 0 operated on one-column data.tables, so the
  # comparison produced a one-column matrix that was then fed to :=.
  for (i in seq_along(sign_names)) {
    spread <- portfolio_data[[top_columns[i]]] - portfolio_data[[bottom_columns[i]]]
    spread_sign <- ifelse(spread >= 0, 1, -1)
    portfolio_data[, (sign_names[i]) := spread_sign]
  }
  portfolio_data[, (weight_names) := lapply(.SD, function(x) x / number_of_currencies), .SDcols = sign_names]

  return(portfolio_data)
}

#--------------------------------------

compute_carry_neutral_weight <- function(portfolio_data, type_of_rate, dollar_weight_names, carry_weight_names, currency_values, reference_values,
                                         currency_names = substr(currency_values, 1, 3)) {

  # Blends 'equal dollar' and 'classic carry' weights, w = w_dollar + alpha * w_carry.
  # alpha is chosen row by row so that the weighted foreign interest equals
  # the USD interest (interest differential to USD of zero):
  #   alpha = -(sum(w_dollar * r) - r_USD) / sum(w_carry * r)
  #
  # Args:
  #  portfolio_data: data table, name of the data table containing all columns.
  #  type_of_rate: string, 'ois' or 'ibor' (currently unused).
  #  dollar_weight_names: string, names of columns containing 'equal dollar' weights.
  #  carry_weight_names: string, names of columns containing 'classic carry' weights.
  #  currency_values: string, names of columns containing 3M forward rates or foreign 3M interest.
  #  reference_values: string, name of the column containing 3M spot rates or USD 3M interest.
  #  currency_names: string, currency codes naming the output '<ccy>_weight'
  #                  columns. Defaults to the first 3 characters of
  #                  currency_values, the same convention as
  #                  compute_dynamic_carry_weight. It was previously read from
  #                  an undeclared global variable.
  #
  # Output:
  #  portfolio_data with 'dollar_diff', 'carry_diff', 'alpha_on_carry' and
  #  '<ccy>_weight' columns added by reference.

  weight_names <- paste(currency_names, 'weight', sep = '_')

  interest_matrix <- as.matrix(portfolio_data[, currency_values, with = FALSE])
  dollar_matrix <- as.matrix(portfolio_data[, dollar_weight_names, with = FALSE])
  carry_matrix <- as.matrix(portfolio_data[, carry_weight_names, with = FALSE])

  dollar_gap <- rowSums(interest_matrix * dollar_matrix) - portfolio_data[[reference_values]]
  carry_gap <- rowSums(interest_matrix * carry_matrix)
  alpha_values <- (0 - dollar_gap) / carry_gap
  portfolio_data[, c('dollar_diff', 'carry_diff', 'alpha_on_carry') := list(dollar_gap, carry_gap, alpha_values)]

  # Sequential, as before: a weight column written on step i is visible to
  # later steps if it is also one of their inputs.
  for (i in seq_along(weight_names)) {
    blended_weight <- portfolio_data[[dollar_weight_names[i]]] + alpha_values * portfolio_data[[carry_weight_names[i]]]
    portfolio_data[, (weight_names[i]) := blended_weight]
  }

  return(portfolio_data)
}

#--------------------------------------

summary_stats_basis <- function(temp_data_table, column_to_summ) {

  # Sample size, mean and standard deviation of one BASIS column, ignoring
  # rows where it is NA. Basis series are already annualised, so no scaling
  # is applied.
  #
  # Args:
  #  temp_data_table: data table containing the desired columns.
  #  column_to_summ: string, name of column to be summarized.
  #
  # Output:
  #  one-row data table with columns n, mean and sd_mean (sd_mean is the
  #  standard deviation of the series, not the standard error of the mean).

  observed <- temp_data_table[[column_to_summ]]
  observed <- observed[!is.na(observed)]

  data.table(n = length(observed),
             mean = mean(observed),
             sd_mean = sd(observed))
}

#-----------------------------------

summary_stats_returns <- function(temp_data_table, column_to_summ, mths_in_return, mths_in_contract, basis_or_percentage, NeweyWest = TRUE, NW_bw = 40) {
  
  # Create mean and Sharpe ratio of the desired RETURN column, along with GMM standard errors
  # (Newey-West HAC or iid) and t-statistics, after first removing rows with NA.
  # Note that the mean return and the Sharpe ratio are annualized.
  # 
  # Args: 
  #  temp_data_table: data table containing the desired columns.
  #  column_to_summ: string, name of column to be summarized.
  #  mths_in_return: numeric, number of months in each return figure, equal to the months in forward / before spot starts.
  #  mths_in_contract: numeric, number of months in the interest rate. Since basis are annualized, a 3M contract earns
  #                    3/12 of the annual basis, which is then annualized by (12 / mths_in_return).
  #  basis_or_percentage: string, 'basis' OR 'percentage'. 'basis' multiplies mean, sd and se by 100;
  #                       'percentage' leaves them unscaled. Any other value is an error.
  #  NeweyWest: logical, TRUE (default) for a Newey-West HAC covariance (Bartlett kernel) in GMM, FALSE for iid.
  #  NW_bw: numeric, Bartlett bandwidth used when NeweyWest is TRUE; default 40, as determined in Stata.
  #
  # Output:
  #  One-row data table with 8 columns: n, mean, se_mean, t_mean, sd_mean, Sharpe, se_Sharpe, t_Sharpe.
  #  sd_mean is the (scaled) standard deviation of the column, not a standard error.
  #
  # Requires moment_function and gradient_function to be defined outside this function (presumably elsewhere
  # in this file). They are assumed to be the GMM moment conditions (and gradient) for the parameter
  # vector c(mean, sd), matching t0 below. TODO confirm.
  # 
  # Deprecated column_of_overlap for bandwidth selection and now uses Newey-West estimate.
  
  if (!isTRUE(basis_or_percentage %in% c('basis', 'percentage'))) {
    stop("basis_or_percentage must be 'basis' or 'percentage'.", call. = FALSE)
  }
  # Without this check a value such as NA fell through both branches and left fit_gmm undefined.
  if (length(NeweyWest) != 1L || !(NeweyWest %in% c(TRUE, FALSE))) {
    stop("NeweyWest must be TRUE or FALSE.", call. = FALSE)
  }
  
  temp_portfolio <- temp_data_table[!is.na(get(column_to_summ)), ]
  returns <- temp_portfolio[, get(column_to_summ)]
  if (length(returns) < 2L) {
    stop("Need at least two non-missing observations in '", column_to_summ, "' to estimate moments.", call. = FALSE)
  }
  
  annual_multiplier <- 12 / mths_in_return
  profit_scale <- mths_in_contract / mths_in_return
  scale_multiplier <- if (basis_or_percentage == 'basis') 100 else 1
  
  # Raw (unscaled) sample moments; they seed GMM and are reused for the delta method.
  raw_mean <- mean(returns)
  raw_sd <- sd(returns)
  mean_scaled <- raw_mean * profit_scale * scale_multiplier
  Sharpe <- raw_mean / raw_sd * sqrt(annual_multiplier)
  
  # Both fits use identity weighting and no prewhitening; only the covariance estimator differs.
  # The kernel name is spelled in full rather than relying on match.arg partial matching of 'Bartlet'.
  if (NeweyWest == TRUE) {
    fit_gmm <- gmm(g = moment_function, x = returns, t0 = c(raw_mean, raw_sd),
                   type = 'twoStep', prewhite = 0, gradv = gradient_function, wmatrix = 'ident',
                   vcov = 'HAC', kernel = 'Bartlett', bw = NW_bw, method = 'BFGS')
  } else {
    fit_gmm <- gmm(g = moment_function, x = returns, t0 = c(raw_mean, raw_sd),
                   type = 'twoStep', prewhite = 0, gradv = gradient_function, wmatrix = 'ident',
                   vcov = 'iid', method = 'BFGS')
  }
  
  variance_matrix <- fit_gmm$vcov
  se_mean <- sqrt(variance_matrix[1, 1]) * profit_scale * scale_multiplier
  
  # Delta method for Sharpe = mu / sigma: gradient with respect to (mu, sigma) is (1 / sigma, -mu / sigma^2).
  gradient_Sharpe <- c(1 / raw_sd, -raw_mean / raw_sd^2)
  # drop() turns the 1x1 quadratic form into a plain scalar before it is stored in the table.
  se_Sharpe <- sqrt(drop(t(gradient_Sharpe) %*% variance_matrix %*% gradient_Sharpe)) * sqrt(annual_multiplier)
  
  data.table(n = length(returns),
             mean = mean_scaled,
             se_mean = se_mean,
             t_mean = mean_scaled / se_mean,
             sd_mean = raw_sd * profit_scale * scale_multiplier,
             Sharpe = Sharpe,
             se_Sharpe = se_Sharpe,
             t_Sharpe = Sharpe / se_Sharpe)
}

#-----------------------------------

summary_stats_returns_NoSE <- function(temp_data_table, column_to_summ, mths_in_return, annualize_Sharpe = FALSE) {
  
  # Create count, mean, standard deviation and Sharpe ratio of the desired RETURN column,
  # after first removing rows with NA. No standard errors are computed (see summary_stats_returns).
  # The mean is annualized by *(12 / mths_in_return) and the sd by *sqrt(12 / mths_in_return).
  # By default the Sharpe ratio is the raw per-period mean / sd and is NOT annualized, which preserves
  # the historical output. Set annualize_Sharpe = TRUE to scale it by sqrt(12 / mths_in_return),
  # consistent with the annualized mean and sd columns.
  # 
  # Args: 
  #  temp_data_table: data table containing the desired columns.
  #  column_to_summ: string, name of column to be summarized.
  #  mths_in_return: numeric, number of months in each return figure, equal to the months in forward / before spot starts.
  #  annualize_Sharpe: logical, default FALSE. If TRUE, the Sharpe ratio is multiplied by sqrt(12 / mths_in_return).
  #
  # Output:
  #  One-row data table with 4 columns: n, mean, sd_mean, Sharpe.
  #  sd_mean is the annualized standard deviation of the column, not a standard error.
  
  temp_portfolio <- temp_data_table[!is.na(get(column_to_summ)), ]
  annual_multiplier <- 12 / mths_in_return
  # Multiplying by exactly 1 leaves the default output bit-for-bit unchanged.
  Sharpe_multiplier <- if (isTRUE(annualize_Sharpe)) sqrt(annual_multiplier) else 1
  temp_summary <- temp_portfolio[, list(n = sum(!is.na(get(column_to_summ))),
                                        mean = mean(get(column_to_summ)) * annual_multiplier,
                                        sd_mean = sd(get(column_to_summ)) * sqrt(annual_multiplier),
                                        Sharpe = mean(get(column_to_summ)) / sd(get(column_to_summ)) * Sharpe_multiplier)]
  
  temp_summary
}

#--------------------------------------------

find_ois_rate <- function(ois_dt, basis_dt, mat_required_name, rate_names, expiry_names) {
  
  # Linearly interpolate the OIS rate at the maturity required on each trade date.
  #
  # For each row of ois_dt, the required maturity is looked up in basis_dt by matching on 'Date'.
  # The lower bracket is the last expiry column whose date is <= that maturity, and the upper bracket
  # is the first expiry column whose date is > it. The rate is interpolated linearly in calendar days
  # between the two corresponding rate columns.
  #
  # Args:
  #  ois_dt: data table with a 'Date' column, OIS rate columns (rate_names) and their expiry-date columns
  #          (expiry_names). It is copied first, so the caller's table is not modified.
  #  basis_dt: data table containing trade dates in 'Date' and the required maturity date.
  #  mat_required_name: string, name of the column of basis_dt containing mat_required. E.g. 'spot_front_imm'
  #  rate_names: character vector of rate column names in ois_dt, positionally aligned with expiry_names
  #              (the k-th rate is read using the bracket index found among the expiry columns).
  #  expiry_names: character vector of expiry-date column names in ois_dt.
  #
  # Output:
  #  Data table with columns 'Date' and 'interpolated'. Rows of ois_dt with an NA in ANY column
  #  (after attaching mat_required) are dropped.
  #
  # NOTE(review): a maturity before the first expiry or at/after the last one has no bracket. max()/min() of an
  # empty which() return -Inf/Inf with a warning, and such rows are not handled.
  
  # Work on a copy: `:=` would otherwise add 'mat_required' to the caller's table by reference.
  ois_dt <- copy(ois_dt)
  ois_dt[, mat_required := basis_dt[, (mat_required_name), with = FALSE][match(ois_dt$Date, basis_dt$Date)]]
  ois_dt <- ois_dt[complete.cases(ois_dt)]
  # apply() coerces .SD to a matrix. Assuming the expiry columns are Date, this is a character matrix,
  # so the comparisons below are on 'YYYY-MM-DD' strings (chronological for ISO dates) -- confirm.
  ois_dt[, lower_index := apply(.SD, 1, function(x) max(which(x[1:length(expiry_names)] <= x[length(expiry_names) + 1]))), .SDcols = c(expiry_names, 'mat_required')]
  ois_dt[, lower_date := apply(.SD, 1, function(x) {i <- x[ length(expiry_names) + 1]
  return(ymd(x[as.numeric(i)]))}), .SDcols = c(expiry_names, 'lower_index')]
  ois_dt[, lower_rate := apply(.SD, 1, function(x) x[x[length(rate_names) + 1]]), .SDcols = c(rate_names, 'lower_index')]
  # apply() simplifies its results with unlist(), which presumably drops the Date class, so lower_date/upper_date
  # hold day counts. The differences below are therefore plain day counts -- confirm.
  ois_dt[, lower_dur := as.numeric(mat_required - lower_date)]
  
  ois_dt[, upper_index := apply(.SD, 1, function(x) min(which(x[1:length(expiry_names)] > x[length(expiry_names) + 1]))), .SDcols = c(expiry_names, 'mat_required')]
  ois_dt[, upper_date := apply(.SD, 1, function(x) {i <- x[ length(expiry_names) + 1]
  return(ymd(x[as.numeric(i)]))}), .SDcols = c(expiry_names, 'upper_index')]
  ois_dt[, upper_rate := apply(.SD, 1, function(x) x[x[length(rate_names) + 1]]), .SDcols = c(rate_names, 'upper_index')]
  ois_dt[, total_dur := as.numeric(upper_date - lower_date)]
  
  # Linear interpolation: lower_rate + slope * days past the lower expiry.
  ois_dt[, interpolated := (upper_rate - lower_rate) / total_dur * lower_dur + lower_rate]
  return(ois_dt[, c('Date', 'interpolated'), with = FALSE])
}

#-----------------------------------
# Graph functions and themes
#-----------------------------------

# Shared ggplot2 theme for charts: bottom legend with blank keys, bold centred title,
# italic centred subtitle, extra space above the x-axis title, black axis lines, blank panel.
theme_chart <- theme(
  legend.position = 'bottom',
  legend.key = element_blank(),
  legend.key.size = unit(0.5, 'cm'),
  legend.text = element_text(size = 10),
  plot.title = element_text(size = 18, face = 'bold', hjust = 0.5, color = 'black'),
  plot.subtitle = element_text(size = 12, face = 'italic', hjust = 0.5, color = 'black'),
  axis.title = element_text(size = 12),
  axis.text = element_text(size = 10),
  axis.title.x = element_text(margin = margin(t = 10, r = 0, b = 0, l = 0)),
  axis.line = element_line(colour = 'black'),
  panel.background = element_blank()
)

## Preferred theme_chart. The 'Arial' font family may not be available on every system, so unless fonts are registered with extrafont::font_import(), use the version above.
# theme_chart <- 
#   theme(legend.position = 'bottom', legend.key = element_blank(), legend.key.size = unit(0.5, 'cm')) +
#   theme(legend.text = element_text(size = 10, family = 'Arial')) + 
#   theme(plot.title = element_text(size = 18, family = 'Arial', face = 'bold', hjust = 0.5, color = 'black')) +
#   theme(plot.subtitle = element_text(size = 12, family = 'Arial', face = 'italic', hjust = 0.5, color = 'black')) +
#   theme(axis.title = element_text(size = 12, family = 'Arial'), axis.text = element_text(size = 10, family = 'Arial')) +
#   theme(axis.title.x = element_text(margin = margin(t = 10, r = 0, b = 0, l = 0))) +
#   theme(axis.line = element_line(colour = 'black')) +
#   theme(panel.background = element_blank())
